#!/usr/bin/env python

import os
import json
import pickle
import sys
import traceback
import pandas as pd
import datetime
from pathlib import Path
import logging
import torch
import argparse
import glob
import time
import random
import re
from itertools import chain
from string import punctuation
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
)

# Module-level logger used by the training code and callbacks in this file.
logger = logging.getLogger(__name__)

# Timestamp captured once, when the module is loaded.
_RUN_TIMESTAMP_FORMAT = "%Y-%m-%d_%H-%M-%S"
run_start_time = datetime.datetime.today().strftime(_RUN_TIMESTAMP_FORMAT)

# Name of the single input channel this script reads from.
channel_name = "training"

# Root of the container's working layout; every path below hangs off it.
prefix = "/opt/ml/"

code_path = os.path.join(prefix, "code")  # /opt/ml/code
output_path = os.path.join(prefix, "output")  # /opt/ml/output
model_path = os.path.join(prefix, "model")  # /opt/ml/model
input_path = os.path.join(prefix, "input/data")  # /opt/ml/input/data

# JSON file holding the hyperparameters of the job.
hyperparam_path = os.path.join(prefix, "input/config/hyperparameters.json")

# This algorithm has a single channel of input data called 'training'. Since we run in
# File mode, the input files are copied to the directory specified here.
training_path = os.path.join(input_path, channel_name)  # /opt/ml/input/data/training

# The training configuration JSON lives in a "config" folder inside the channel.
training_config_path = os.path.join(input_path, channel_name + "/config")
config_path = os.path.join(training_config_path, "training_config.json")


def searching_all_files(directory: Path) -> list:
    """Recursively collect the paths of all files found under *directory*.

    Args:
        directory: Directory to walk. Sub-directories are descended into.

    Returns:
        A flat list of path strings, one per file. Directories themselves are
        not listed, so an empty directory tree yields an empty list.
    """
    file_list = []  # Flat list of file paths found so far

    for entry in directory.iterdir():
        if entry.is_dir():
            # extend (not append) so files of sub-directories end up in the same
            # flat list instead of becoming nested lists inside it.
            file_list.extend(searching_all_files(entry))
        elif entry.is_file():
            file_list.append(str(entry))
        # Anything else (e.g. a broken symlink) is neither listable nor a file,
        # so it is skipped rather than crashing the walk.

    return file_list


def _config_flag_enabled(value):
    """Return True for a JSON ``true`` or for the text "true" in any letter case."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() == "true"


# The function to execute the training.
def train():
    """Run one fine-tuning job end to end.

    Reads the training configuration from ``config_path`` (keys used:
    ``model_type``, ``grad_accumulation_steps``, ``fp16``, ``fp16_opt_level``)
    and the hyperparameters from ``hyperparam_path`` (keys used:
    ``max_seq_length``, ``lr``, ``warmup_steps``, ``train_batch_size``,
    ``epochs``), builds a ``T5FineTuner``, fits it with a Lightning trainer and
    saves the underlying model into ``model_path``.

    Any exception is written to ``<output_path>/failure``, echoed to stderr,
    and the process then exits with status 255.
    """

    print("Starting the training.")

    try:
        print(config_path)
        with open(config_path, "r") as f:
            training_config = json.load(f)
            print(training_config)

        with open(hyperparam_path, "r") as tc:
            hyperparameters = json.load(tc)
            print(hyperparameters)

        args_dict = dict(
            data_dir=training_path,
            output_dir=model_path,
            model_name_or_path=training_config["model_type"],
            tokenizer_name_or_path=training_config["model_type"],
            max_seq_length=int(hyperparameters["max_seq_length"]),
            learning_rate=float(hyperparameters["lr"]),
            weight_decay=0.0,
            adam_epsilon=1e-8,
            warmup_steps=int(hyperparameters["warmup_steps"]),
            train_batch_size=int(hyperparameters["train_batch_size"]),
            # Evaluation deliberately reuses the training batch size.
            eval_batch_size=int(hyperparameters["train_batch_size"]),
            num_train_epochs=int(hyperparameters["epochs"]),
            gradient_accumulation_steps=int(training_config["grad_accumulation_steps"]),
            n_gpu=torch.cuda.device_count(),
            early_stop_callback=False,
            # Accept both a JSON boolean and the strings "True"/"true"; comparing
            # only against the literal "True" silently disabled fp16 otherwise.
            # If you want to enable 16-bit training then install apex and set this to true.
            fp_16=_config_flag_enabled(training_config["fp16"]),
            opt_level=training_config[
                "fp16_opt_level"
            ],  # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
            max_grad_norm=0.5,  # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
            seed=42,
        )

        logger.info("Number of GPUs: {}".format(torch.cuda.device_count()))
        args = argparse.Namespace(**args_dict)

        # Checkpointing is currently disabled (see checkpoint_callback=None below).
        # checkpoint_callback = pl.callbacks.ModelCheckpoint(
        #     filepath=args.output_dir,
        #     prefix="checkpoint",
        #     monitor="val_loss",
        #     mode="min",
        #     save_top_k=5,
        # )

        set_seed(args.seed)

        train_params = dict(
            accumulate_grad_batches=args.gradient_accumulation_steps,
            gpus=args.n_gpu,
            max_epochs=args.num_train_epochs,
            early_stop_callback=False,
            precision=16 if args.fp_16 else 32,
            amp_level=args.opt_level,
            gradient_clip_val=args.max_grad_norm,
            checkpoint_callback=None,
            callbacks=[LoggingCallback()],
        )

        # Only request the distributed backend when more than one GPU is visible.
        if args.n_gpu > 1:
            train_params["distributed_backend"] = "ddp"

        print("Initialize model")
        model = T5FineTuner(args)

        trainer = pl.Trainer(**train_params)

        print(" Training model")
        trainer.fit(model)

        print("training finished")

        print("Saving model")
        # Make sure the target directory is there before the model is written.
        os.makedirs(args.output_dir, exist_ok=True)
        model.model.save_pretrained(args.output_dir)

        print("Saved model")

    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        failure_message = "Exception during training: " + str(e) + "\n" + trc
        try:
            # The directory may be missing (e.g. when run outside the container);
            # without this, open() would raise and hide the original error.
            os.makedirs(output_path, exist_ok=True)
            with open(os.path.join(output_path, "failure"), "w") as s:
                s.write(failure_message)
        except OSError as write_error:
            # Never let a reporting problem replace the real failure.
            print(
                "Could not write failure file: " + str(write_error), file=sys.stderr
            )
        # Printing this causes the exception to be in the training job logs, as well.
        print(failure_message, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


class T5FineTuner(pl.LightningModule):
    """Lightning module wrapping a T5 conditional-generation model and its tokenizer.

    ``hparams`` is expected to expose, as attributes, the settings read below:
    ``model_name_or_path``, ``tokenizer_name_or_path``, ``weight_decay``,
    ``learning_rate``, ``adam_epsilon``, ``warmup_steps``, ``train_batch_size``,
    ``eval_batch_size``, ``n_gpu``, ``gradient_accumulation_steps``,
    ``num_train_epochs``, plus whatever ``get_dataset`` needs.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        # Model and tokenizer are both loaded from the names given in hparams.
        self.model = T5ForConditionalGeneration.from_pretrained(
            hparams.model_name_or_path
        )
        self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path)

    def is_logger(self):
        """Tell whether this process should emit log output (trainer rank <= 0)."""
        rank = self.trainer.proc_rank
        return rank <= 0

    def forward(
        self,
        input_ids,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        lm_labels=None,
    ):
        """Delegate to the wrapped T5 model and return its outputs unchanged."""
        optional_inputs = dict(
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            lm_labels=lm_labels,
        )
        return self.model(input_ids, **optional_inputs)

    def _step(self, batch):
        """Run the model on one batch and return the first model output (the loss).

        Note: the label tensor taken from ``batch["target_ids"]`` is modified in
        place, with padding positions overwritten by -100.
        """
        labels = batch["target_ids"]
        padding_positions = labels[:, :] == self.tokenizer.pad_token_id
        labels[padding_positions] = -100

        model_outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            lm_labels=labels,
            decoder_attention_mask=batch["target_mask"],
        )
        return model_outputs[0]

    def training_step(self, batch, batch_idx):
        """Compute the loss for a training batch and expose it for logging."""
        step_loss = self._step(batch)
        return {"loss": step_loss, "log": {"train_loss": step_loss}}

    def training_epoch_end(self, outputs):
        """Average the per-step training losses collected over the epoch."""
        step_losses = [step_output["loss"] for step_output in outputs]
        mean_loss = torch.stack(step_losses).mean()
        # The same dict is handed out for both the log and the progress bar.
        epoch_logs = {"avg_train_loss": mean_loss}
        return {
            "avg_train_loss": mean_loss,
            "log": epoch_logs,
            "progress_bar": epoch_logs,
        }

    def validation_step(self, batch, batch_idx):
        """Compute the loss for a validation batch."""
        return {"val_loss": self._step(batch)}

    def validation_epoch_end(self, outputs):
        """Average the per-step validation losses collected over the epoch."""
        step_losses = [step_output["val_loss"] for step_output in outputs]
        mean_loss = torch.stack(step_losses).mean()
        # The same dict is handed out for both the log and the progress bar.
        epoch_logs = {"val_loss": mean_loss}
        return {
            "avg_val_loss": mean_loss,
            "log": epoch_logs,
            "progress_bar": epoch_logs,
        }

    def configure_optimizers(self):
        """Build the AdamW optimizer, keep it on ``self.opt`` and return it in a list.

        Parameters whose name contains "bias" or "LayerNorm.weight" go into a
        group with zero weight decay; all others use ``hparams.weight_decay``.
        """
        exempt_markers = ("bias", "LayerNorm.weight")
        decayed_params = []
        exempt_params = []
        for param_name, param in self.model.named_parameters():
            if any(marker in param_name for marker in exempt_markers):
                exempt_params.append(param)
            else:
                decayed_params.append(param)

        param_groups = [
            {"params": decayed_params, "weight_decay": self.hparams.weight_decay},
            {"params": exempt_params, "weight_decay": 0.0},
        ]
        optimizer = AdamW(
            param_groups,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )
        # Kept so train_dataloader() can attach the LR scheduler to it.
        self.opt = optimizer
        return [optimizer]

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        second_order_closure=None,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Apply one optimizer update, clear gradients, then advance the LR schedule."""
        optimizer.step()
        optimizer.zero_grad()
        self.lr_scheduler.step()

    def get_tqdm_dict(self):
        """Return the values shown in the progress bar: formatted loss and last LR."""
        loss_text = "{:.3f}".format(self.trainer.avg_loss)
        last_lr = self.lr_scheduler.get_last_lr()[-1]
        return {"loss": loss_text, "lr": last_lr}

    def train_dataloader(self):
        """Create the training DataLoader and, as a side effect, the LR scheduler.

        The scheduler is attached to ``self.opt``, so ``configure_optimizers``
        must already have run when this method is called.
        """
        train_dataset = get_dataset(
            tokenizer=self.tokenizer, type_path="train", args=self.hparams,
        )
        print("Training set size: {}".format(len(train_dataset)))
        dataloader = DataLoader(
            train_dataset,
            batch_size=self.hparams.train_batch_size,
            drop_last=True,
            shuffle=True,
            num_workers=4,
        )

        # Total number of optimizer updates: batches per epoch across all GPUs,
        # reduced by gradient accumulation, times the number of epochs.
        samples_per_update = self.hparams.train_batch_size * max(
            1, self.hparams.n_gpu
        )
        batches_per_epoch = len(dataloader.dataset) // samples_per_update
        updates_per_epoch = (
            batches_per_epoch // self.hparams.gradient_accumulation_steps
        )
        t_total = updates_per_epoch * float(self.hparams.num_train_epochs)

        self.lr_scheduler = get_linear_schedule_with_warmup(
            self.opt,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=t_total,
        )
        return dataloader

    def val_dataloader(self):
        """Create the validation DataLoader (no shuffling, eval batch size)."""
        val_dataset = get_dataset(
            tokenizer=self.tokenizer, type_path="val", args=self.hparams,
        )
        print("Validation set size: {}".format(len(val_dataset)))
        loader_options = dict(batch_size=self.hparams.eval_batch_size, num_workers=4)
        return DataLoader(val_dataset, **loader_options)


class LoggingCallback(pl.Callback):
    """Callback that reports the trainer's callback metrics after validation and test."""

    # Entries of the metrics dict that are containers rather than metrics.
    _SKIPPED_KEYS = ["log", "progress_bar"]

    def on_validation_end(self, trainer, pl_module):
        """Log every callback metric, in sorted key order, on the logging process."""
        logger.info("***** Validation results *****")
        if not pl_module.is_logger():
            return

        metrics = trainer.callback_metrics
        # Log results
        for name in sorted(metrics):
            if name in self._SKIPPED_KEYS:
                continue
            logger.info("{} = {}\n".format(name, str(metrics[name])))

    def on_test_end(self, trainer, pl_module):
        """Log every callback metric and also write them to ``test_results.txt``.

        The file is created inside ``pl_module.hparams.output_dir`` and only on
        the logging process.
        """
        logger.info("***** Test results *****")
        if not pl_module.is_logger():
            return

        metrics = trainer.callback_metrics

        # Log and save results to file
        results_path = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
        with open(results_path, "w") as writer:
            for name in sorted(metrics):
                if name in self._SKIPPED_KEYS:
                    continue
                line = "{} = {}\n".format(name, str(metrics[name]))
                logger.info(line)
                writer.write(line)


class ParaphraseDataset(Dataset):
    """Dataset of tokenized source/target text pairs loaded from ``<type_path>.csv``.

    Every row of the CSV is tokenized eagerly at construction time; the source
    text is prefixed with "paraphrase: " and both texts get " </s>" appended.
    """

    def __init__(
        self,
        tokenizer,
        data_dir,
        type_path,
        source_col="text1",
        target_col="text2",
        max_len=256,
    ):
        """Read the CSV and tokenize all rows.

        Args:
            tokenizer: Object providing ``batch_encode_plus``.
            data_dir: Directory that contains the CSV file.
            type_path: File name without the ".csv" extension (e.g. "train").
            source_col: Name of the column holding the input text.
            target_col: Name of the column holding the target text.
            max_len: ``max_length`` handed to the tokenizer for both texts.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.source_column = source_col
        self.target_column = target_col

        self.path = os.path.join(data_dir, type_path + ".csv")
        self.data = pd.read_csv(self.path)
        print(self.data.shape)

        # Filled by _build(): one tokenizer result per CSV row, in row order.
        self.inputs = []
        self.targets = []
        self._build()

    def __len__(self):
        """Number of tokenized examples."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return ids and attention masks of example *index*, squeezed of size-1 dims."""
        encoded_source = self.inputs[index]
        encoded_target = self.targets[index]

        source_ids = encoded_source["input_ids"].squeeze()
        target_ids = encoded_target["input_ids"].squeeze()
        source_mask = encoded_source["attention_mask"].squeeze()
        target_mask = encoded_target["attention_mask"].squeeze()

        return {
            "source_ids": source_ids,
            "source_mask": source_mask,
            "target_ids": target_ids,
            "target_mask": target_mask,
        }

    def _encode(self, text):
        """Tokenize a single text with padding to ``max_len`` and return the tokenizer result."""
        return self.tokenizer.batch_encode_plus(
            [text],
            max_length=self.max_len,
            pad_to_max_length=True,
            return_tensors="pt",
        )

    def _build(self):
        """Tokenize every CSV row and store the results in ``inputs`` / ``targets``."""
        for row_label in range(len(self.data)):
            raw_source = self.data.loc[row_label, self.source_column]
            raw_target = self.data.loc[row_label, self.target_column]

            # str() also turns non-text cells (numbers, missing values) into text.
            source_text = "paraphrase: " + str(raw_source) + " </s>"
            target_text = str(raw_target) + " </s>"

            encoded_source = self._encode(source_text)
            encoded_target = self._encode(target_text)

            # Only the underlying ``.data`` mapping of each result is kept.
            self.inputs.append(encoded_source.data)
            self.targets.append(encoded_target.data)


def get_dataset(tokenizer, type_path, args):
    """Build the ParaphraseDataset for one split.

    Args:
        tokenizer: Tokenizer handed through to the dataset.
        type_path: Split name, used by the dataset as the CSV file stem.
        args: Object providing ``data_dir`` and ``max_seq_length`` attributes.

    Returns:
        A ``ParaphraseDataset`` using the dataset's default column names.
    """
    dataset_options = {
        "tokenizer": tokenizer,
        "data_dir": args.data_dir,
        "type_path": type_path,
        "max_len": args.max_seq_length,
    }
    return ParaphraseDataset(**dataset_options)


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus all GPUs if CUDA is available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


if __name__ == "__main__":
    # Script entry point: run the training job.
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    raise SystemExit(0)
